Modeling single cell trajectory using forward-backward stochastic differential equations

Recent advances in single-cell sequencing technology have provided opportunities for mathematical modeling of dynamic developmental processes at the single-cell level, such as inferring developmental trajectories. Optimal transport has emerged as a promising theoretical framework for this task by computing pairings between cells from different time points. However, optimal transport methods have limitations in capturing nonlinear trajectories, as they are static and can only infer linear paths between endpoints. In contrast, stochastic differential equations (SDEs) offer a dynamic and flexible approach that can model non-linear trajectories, including the shape of the path. Nevertheless, existing SDE methods often rely on numerical approximations that can lead to inaccurate inferences, deviating from true trajectories. To address this challenge, we propose a novel approach combining forward-backward stochastic differential equations (FBSDE) with a refined approximation procedure. Our FBSDE model integrates the forward and backward movements of two SDEs in time, aiming to capture the underlying dynamics of single-cell developmental trajectories. Through comprehensive benchmarking on multiple scRNA-seq datasets, we demonstrate the superior performance of FBSDE compared to other methods, highlighting its efficacy in accurately inferring developmental trajectories.


Introduction
The technology of single-cell RNA sequencing (scRNA-seq) is a revolutionary breakthrough in the study of cellular developmental dynamics [1,2]. With the availability of time-series scRNA-seq data, lineage tracing and cell differentiation processes have been studied in multiple cell types, such as salivary gland, liver, lung, kidney, neuronal, and tumor cells [3][4][5][6][7][8]. Despite this success, sequencing is destructive: the sampled cells are destroyed, so their future gene expression cannot be measured. Several mathematical approaches have emerged that take these time-series data as input and aim to recover the underlying cellular differentiation processes. Earlier modelling approaches include continuous-state hidden Markov models (CSHMM) [9] and the graph-based model Tempora [10]. These methods typically cluster cells into a small number of distinct cell types and model the trajectories of these cell types instead of individual cells. They are simpler to fit and easier to interpret, but they lose cell-level information by reducing a population of cells to discrete cell types. More recently, optimal transport (OT) has become a state-of-the-art approach for modelling cell developmental trajectories from scRNA-seq data. In particular, Schiebinger et al. [11] described a framework (Waddington-OT) that identifies pairings between cells from different time points, assuming that the gene expression of single cells follows a developmental trajectory with the lowest transport cost. Waddington-OT uses a joint probability mass function to infer how likely a cell sampled at an earlier time point is to become a cell sampled at a later time point. The key assumption of the OT model is that the cell population evolves in an energy-efficient fashion, analogous to an optimal transport plan in probability theory. In this context, the energy is defined by the differences in genome-wide gene expression profiles of a population of cells between consecutive time points [11]. In contrast to graph-based models, optimal transport does not require clustering cells into discrete cell types. Instead, it models developmental processes at the single-cell level, making it a potentially more powerful approach.
Despite its success and later improvements [12], Waddington-OT still has two limitations. First, it can only infer the endpoints of a path from sampled cells, which makes generative modelling infeasible because paths starting from an unobserved point cannot be simulated. Second, trajectory inference in Waddington-OT is static, as it only provides a linear interpolation between endpoints; that is, the inferred path is simply a straight line connecting them. Instead of matching path endpoints with OT, a potential solution is to model path movements with stochastic differential equations (SDEs), which are better suited to dynamic trajectory inference [13,14]. In general, an SDE model treats $X_t$ as the gene expression of a cell in a population evolving between the initial time $t = 0$ and the terminal time $t = T$; the change in $X_t$ at time $t$ is $dX_t = v\,dt + \sigma\,dW_t$, where $v$ and $\sigma$ denote the drift and the volatility. The drift $v$ takes a cell's spatial and temporal information as input and outputs the direction in which the cell changes its gene expression. The volatility $\sigma$ models the pure randomness in cell movements by scaling a Brownian motion $W_t$. As a result, an SDE can parametrize non-linear trajectories and supports generative simulation on a continuous timeline. The volatility $\sigma$ in an SDE is also equivalent to the entropy regularization in optimal transport [15]. In practice, SDE models can be parameterized by neural networks [16,17], which have shown good performance in recovering scRNA-seq trajectories [18].
Single-cell sequencing has also allowed us to study cell-to-cell interactions, which were not included in earlier trajectory inference methods [19,20]. In biology, it is well established that cells can influence the development of other cells by modulating their gene expression or other cellular programs. This modulation can occur through the secretion of small molecules or via direct interactions between receptors and ligands on cell surfaces. Interestingly, these cell-to-cell interactions resemble concepts in mean field control theory, which aims to optimize the collective behavior of a population while accounting for interactions among its individual members [16,17]. Drawing on this idea, a recent model called TrajectoryNet [18] incorporated mean field control theory to capture interactions among single cells.
In the following, we discuss the rationale and limitations of TrajectoryNet, the method most closely related to our approach, which serves as a baseline for our method. The model assumes that the gene expression of cells at each time point $t = 0, 1, 2, \dots, T$ follows a distribution $q_t$. TrajectoryNet starts from $q_0$ and generates a sequence of distributions $\rho_t$ for $t = 1, 2, \dots, T$ based on the SDE model (see Eq 1 in Methods). In practice, this is done by simulation: $m$ cells are randomly subsampled from time point $t = 0$, and their states at time points $t = 1, 2, \dots, T$ are simulated with the SDE model. The gene expressions of these simulated cells are denoted by $x^*_{t,1}, x^*_{t,2}, \dots, x^*_{t,m}$, and their empirical distribution is denoted by $\hat{\rho}_t$. The main optimization problem in TrajectoryNet is similar to a road-building problem, where the objective is to construct a road connecting two points A and B while passing through multiple intermediate checkpoints. TrajectoryNet seeks an optimal SDE model that minimizes the collective displacement of simulated cells, akin to the construction cost of a road. Additionally, TrajectoryNet incorporates interactions among cells, resembling cars passing and merging on the road. As in the road-building problem, one would expect $\hat{\rho}_t = q_t$ for each time point $t = 0, 1, 2, \dots, T$; these equalities play the role of the initial and terminal points of the road and the intermediate checkpoints it must pass through. However, TrajectoryNet only satisfies the initial constraint $\hat{\rho}_0 = q_0$, not the subsequent constraints at $t = 1, 2, \dots, T$. Instead, it relaxes these constraints by penalizing the divergence between $q_t$ and $\hat{\rho}_t$; in other words, it approximates the constraints as $\hat{\rho}_t \approx q_t$ for $t = 1, 2, \dots, T$. Consequently, TrajectoryNet is less effective because the simulated distributions $\hat{\rho}_t$ do not exactly match the true distributions $q_t$ for $t = 1, 2, \dots, T$.
Motivated by these limitations, we propose two key enhancements to trajectory inference. First, we improve the approximation of the intermediate constraints by introducing a new penalty term in the optimization problem. Second, we enforce the terminal constraint $\hat{\rho}_T = q_T$ by implementing the FBSDE method (see Fig 1 for a visual illustration). Together, these enhancements aim to improve the accuracy of trajectory inference by achieving a closer alignment between the simulated distributions $\hat{\rho}_t$ and the true distributions $q_t$ for each time point $t = 1, 2, \dots, T$, so that the inferred trajectories more faithfully capture the underlying dynamics of the system. In the rest of this section, we provide an overview of these approaches and their key components.
To improve the approximation of the intermediate constraints, we introduce a new penalty term in addition to those of TrajectoryNet discussed above [18]. To satisfy the conditions $\hat{\rho}_t \approx q_t$ for all time points $t$, TrajectoryNet implements two penalty terms. First, it directly penalizes the divergence between $q_t$ and $\hat{\rho}_t$ using a maximum likelihood criterion. Second, it adds a penalty proportional to the sum of distances to the $k$ nearest neighbours. To illustrate this, suppose we have $n$ observations and $n$ simulations at each time point $t$, and denote by $x_{t,i}$ and $x^*_{t,j}$ the gene expressions of the $i$th observed cell and the $j$th simulated cell at time $t$, respectively. The conditions $\hat{\rho}_t \approx q_t$ imply a large overlap between the regions of gene expression space occupied by the observations and by the simulations (see Fig 2a and 2b). At each time point $t$, for each simulated point $x^*_{t,j}$ (i.e., the gene expression of a simulated cell), TrajectoryNet computes its distances to the $k$ closest observations among $\{x_{t,i}\}$. These soft marginal constraints force the simulated points to stay in a close neighbourhood of observations, so simulations far away from all observations become unlikely. Therefore, this penalty improves the precision of the model. We modify the $k$-nearest-neighbour distances in TrajectoryNet [18] by additionally computing, for each observed point $x_{t,i}$, its distances to the $k$ closest simulations among $\{x^*_{t,j}\}$. This additional term brings more observations into a close neighbourhood of simulated points (see Methods for details). As a result, our model also improves recall, as illustrated in Fig 2b. To enforce the exact terminal equality constraint $\hat{\rho}_T = q_T$, we use the forward-backward stochastic differential equations (FBSDE) proposed by Vargas et al. [21]. The FBSDE model consists of a forward model that runs from the initial time to the terminal time ($0 \to T$) and a backward model that runs from the terminal time to the initial time ($T \to 0$) (see Methods for details). The forward model generates training data for the backward model and vice versa. The core idea of FBSDE is that training two models whose time runs in opposite directions eventually leads to convergence, such that $\hat{\rho}_T = q_T$ is met by the forward model and $\hat{\rho}_0 = q_0$ is met by the backward model. A recent study by Bunne et al. models cell trajectories with FBSDEs and demonstrates that the model indeed preserves the terminal equality constraint $\hat{\rho}_T = q_T$ [22]. However, this implementation of the FBSDE model considers only the initial and terminal constraints and does not model the intermediate time points. To further exploit the information from the observed intermediate time points, we extend the original FBSDE by enforcing soft constraints at all intermediate time points, as discussed above.
We benchmark FBSDE against two other methods, Waddington-OT (WOT) and TrajectoryNet, on three scRNA-seq datasets: the Human Embryonic Stem Cells Dataset [23], the Mouse Embryonic Fibroblasts Dataset [11], and the Arabidopsis thaliana Stem Cells Dataset [12]. We adopt a standard leave-one-out benchmark framework, i.e., we remove one time point and try to recover the gene expression at the hidden time point from the other time points. Our results show that the FBSDE model consistently outperforms TrajectoryNet and Waddington-OT.
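The leave-one-out protocol can be sketched as a loop over held-out time points. This is a minimal illustration rather than the benchmark code: the function name `leave_one_out` and the toy data are hypothetical, and we assume that only intermediate time points are held out, since recovering the first or last time point would require extrapolation rather than interpolation.

```python
import numpy as np

def leave_one_out(data):
    """Yield (held-out time, training data, held-out cells) for each intermediate time.

    `data` maps each observed time point to an (n_cells, d) array of expression
    embeddings. Endpoints are kept in every split (an assumption of this sketch).
    """
    times = sorted(data)
    for t_hidden in times[1:-1]:
        train = {t: x for t, x in data.items() if t != t_hidden}
        yield t_hidden, train, data[t_hidden]

rng = np.random.default_rng(0)
data = {t: rng.normal(size=(100, 2)) for t in range(5)}   # toy: 5 time points, 2-D embedding
splits = list(leave_one_out(data))
print([t for t, _, _ in splits])   # [1, 2, 3]
```

Each split would then be used to fit a model on `train` and compare its prediction at the hidden time point against the held-out cells with a distributional distance of the user's choice.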

Methods
This section is structured as follows. In Mathematical details, we begin with a mathematical formulation of the developmental trajectory of a single cell and discuss the optimization problem typically employed in trajectory inference methods based on stochastic differential equations (SDEs). We also emphasize the importance of the constraints in this optimization problem. Additionally, we present an overview of the TrajectoryNet framework [18], which serves as a baseline for comparison.
Next, in Estimation of cost functional, we address the approximation of the cost functional, which is the objective of our optimization problem. The cost functional is defined as the expectation of a stochastic integral. To approximate this integral, we discretize time and simulate the SDE with the Euler-Maruyama method. We further discuss the approximation of probability density functions (PDFs), which play a vital role in SDE-based trajectory inference. Using these numerical techniques, we show how the cost functional can be approximated from a finite sample.
In Approximation of intermediate constraints, we propose modifications to the TrajectoryNet framework that improve the approximation of the intermediate constraints. We explain in detail how these modifications can be implemented with a finite set of cells, leading to more accurate trajectory inference.
Finally, in Satisfying terminal constraint, we introduce the forward-backward stochastic differential equation (FBSDE) system and highlight its advantages for satisfying the terminal constraint. We implement a numerical algorithm proposed by Vargas et al. [21] for solving the FBSDE system.

Mathematical details
Suppose all cells live in the space $\mathbb{R}^d$ over the finite time frame $[0, T]$ for some constant $T > 0$. The gene expression profile of a cell at time $t$ is denoted by $X_t \in \mathbb{R}^d$, a continuous random vector with probability density function $q_t$ for $t \in [0, T]$.
Remark 1. In practical applications, gene expression data are often preprocessed and undergo dimensionality reduction. Typically, the reduced dimension $d$ is much smaller, commonly 2 or 3. In our study, all datasets were obtained as two-dimensional data published by the authors cited in our manuscript, who applied dimensionality reduction techniques such as PHATE and UMAP to their raw data. These techniques were chosen for their effectiveness in preserving the global structure of high-dimensional data. When selecting an appropriate dimension, we advise users to follow the protocols of the chosen dimensionality reduction technique together with domain-specific knowledge.
We are interested in the single-cell developmental trajectory $\{X_t, t \in [0, T]\}$. Let $X_0$ be the initial gene expression of a cell at time 0; then the trajectory of this cell at time $t$ can be modeled by the stochastic differential equation described by Zhang et al. [12]:
$$X_t = X_0 + \int_0^t v(X_s, s)\,ds + \int_0^t \sigma(X_s, s)\,dW_s, \tag{1}$$
where $v : \mathbb{R}^d \times [0, T] \to \mathbb{R}^d$ is a drift function and $v(X_s, s)$ characterizes the shift caused by $X_s$ at time $s$. The integral $\int_0^t v(X_s, s)\,ds$ represents the cumulative mean shift along the trajectory of $X_s$ up to time $t$ and is known as the drift term. The function $\sigma$, defined on $\mathbb{R}^d \times [0, T]$, is a volatility function, and $\{W_t\}_{0 \le t \le T}$ is a standard Brownian motion. At time $s$, $dW_s$ denotes the increment of the Brownian motion $W_s$. The term $\int_0^t \sigma(X_s, s)\,dW_s$ represents the random variation accumulated along the trajectory of $X_s$ up to time $t$ and is known as the volatility term.
To infer the cell trajectory, we aim to learn the unknown drift function $v(X_t, t)$, also called the action. In our analysis, we assume the volatility function is constant and set it to 1 by default.
Remark 2. We advise users to normalize their input data. This step is important because it reduces the effect of data scaling on the diffusion term, so that the method behaves consistently and reliably across datasets. The selection of $\sigma$ is discussed in Parameter tuning.
Eq 1 characterizes the temporal dynamics of gene expression in an individual cell. In a population of millions of cells, each cell changes its gene expression simultaneously. Consequently, the action $v$ governs the evolution of the population distribution over time. Given an initial distribution $\rho_0$, the sequence of distributions $\{\rho(\cdot, t)\}_{0 < t \le T}$ can be expressed as a function of the action $v$ and the volatility $\sigma$ through the well-known Fokker-Planck equation [24]. Here, $\rho(\cdot, t)$ denotes the probability density function generated by the action $v$, with $\rho_t = \rho(\cdot, t)$ and $\rho : \mathbb{R}^d \times [0, T] \to \mathbb{R}$. We assume that $\rho$ is continuous with continuous partial derivatives on $\mathbb{R}^d \times [0, T]$.
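For reference, when the volatility $\sigma$ is a constant scalar (the default assumed here), the Fokker-Planck equation for the SDE $dX_t = v(X_t, t)\,dt + \sigma\,dW_t$ takes the standard textbook form below; it is stated as background, not as a formula taken from the derivation in this paper:
$$\frac{\partial \rho(x, t)}{\partial t} = -\nabla_x \cdot \big(\rho(x, t)\, v(x, t)\big) + \frac{\sigma^2}{2}\,\Delta_x \rho(x, t), \qquad \rho(\cdot, 0) = \rho_0.$$
The first term transports probability mass along the action $v$, and the second spreads it by diffusion. With $\sigma = 0$ the equation reduces to the continuity equation of a deterministic flow, which is why the choice of $v$ alone determines how the population density moves.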
As cells change their gene expression over time, there is often an associated energy cost. The optimal action $\tilde{v}$ can be determined by seeking a function that minimizes this cost functional. This raises the question of whether there is an optimal strategy for all cells in the population to collectively modify their gene expression so as to minimize the cost functional. In general, most SDE-based inference methods minimize a cost functional of the form
$$\mathcal{J}(v, X_0) = \mathbb{E}_{X_0 \sim q_0}\left[\int_0^T \left(\frac{1}{2}\big\|v(X_t, t)\big\|^2 + \alpha_F\, F\big(\rho(X_t, t)\big)\right) dt\right], \tag{2}$$
where $F$ denotes the interactive (running) cost described below and $\alpha_F > 0$ is a regularization constant tuned by cross-validation. Since the initial gene expression $X_0$ of a cell is a random variable with distribution $q_0$, the cost functional is defined as the expected cost of a trajectory over all possible values of $X_0$. The first term in the integral represents the non-interactive cost, where a cell's trajectory is independent of the actions of other cells. The integral of this term is typically referred to as the kinetic energy of the trajectory, i.e., the energy required to change a cell's gene expression along this trajectory. In practice, the trajectory of one cell is likely to be influenced by other cells; this is captured by the second term, the running or interactive cost, which represents how a cell is affected by other cells while performing action $v$. This term increases with the probability density at $X_t$, similar to the idea of crowdedness or entropy in gene expression space. For example, cells with similar gene expression may compete for resources for growth, so trajectories through densely populated regions (i.e., with higher $\rho(X_t, t)$) incur higher costs.
In trajectory inference settings [11,25], one typically observes $q_t$ (more precisely, its empirical version) at a grid of time points $0 = t_0 < t_1 < t_2 < \dots < t_K = T$. The trajectory inference problem is to determine the optimal action $\tilde{v}$ such that the population evolves from $q_0$ through $q_{t_1}, q_{t_2}, \dots$ and ultimately to $q_T$, while minimizing the cost functional in Eq 2.
Let $\{\rho_t : t \in (0, T]\}$ denote the sequence of distributions generated from the initial distribution $q_0$ by the SDE model governed by the action $v$. The trajectory inference problem can then be formulated as the constrained optimization problem
$$\begin{aligned} \min_{v}\ & \mathcal{J}(v, X_0) \\ \text{subject to: } & dX_t = v(X_t, t)\,dt + \sigma\,dW_t, \\ & \rho_0 = q_0 \ \text{(initial)}, \qquad \rho_T = q_T \ \text{(terminal)}, \\ & \rho_{t_k} = q_{t_k}, \quad k = 1, \dots, K-1 \ \text{(intermediate)}. \end{aligned} \tag{4}$$
Remark 3. Consider the task of building a road that connects two points A and B while passing through multiple intermediate checkpoints. In Eq 4, the SDE constraint $dX_t = v(X_t, t)\,dt + \sigma\,dW_t$ models the shape of the road. The initial, terminal, and intermediate constraints are analogous to the endpoints A and B and the intermediate checkpoints, respectively. The cost functional $\mathcal{J}$ corresponds to the cost of building the road. Thus, this optimization problem can be viewed as designing a road that connects all required checkpoints while minimizing the overall cost $\mathcal{J}$.
However, previous studies [18,21,26] have argued that imposing the equality constraints $\rho_{t_k} = q_{t_k}$ can be numerically challenging. These methods satisfy the initial constraint $\rho_0 = q_0$ by design. For all subsequent constraints at $t = t_1, t_2, \dots, T$, they approximate the equality constraints by $\rho_{t_k} \approx q_{t_k}$. One example is TrajectoryNet, which solves the following optimization problem:
$$\begin{aligned} \min_{v}\ & \mathcal{J}(v, X_0) \\ \text{subject to: } & dX_t = v(X_t, t)\,dt + \sigma\,dW_t, \\ & \rho_0 = q_0, \\ & \rho_{t_k} \approx q_{t_k}, \quad k = 1, \dots, K. \end{aligned} \tag{5}$$
Ideally, the optimal solution to the relaxed problem in Eq 5 should be close to the optimal solution to the constrained problem in Eq 4. Nonetheless, the relaxed problem is prone to the accumulation of errors: the accuracy of approximating the constraint $\rho_{t_{k+1}} \approx q_{t_{k+1}}$ depends on the accuracy of approximating the previous constraint $\rho_{t_k} \approx q_{t_k}$. As a result, approximation errors accumulate as $t_k$ increases, potentially causing the terminal distribution $\rho_T$ to deviate significantly from the true distribution $q_T$. This deviation can lead to erroneous conclusions about the gene expression of fully differentiated cells in the population. To mitigate this issue, we enforce the exact terminal equality constraint $\rho_T = q_T$, which we expect to also improve the approximation of the intermediate constraints. To fulfill the terminal constraint, we employ the forward-backward stochastic differential equations (FBSDE) model proposed by Vargas et al. [21]. In the following sections, we describe how we integrate the TrajectoryNet framework with FBSDE to satisfy the terminal equality constraint and approximate the intermediate equality constraints. Specifically, we aim to solve the following optimization problem:
$$\begin{aligned} \min_{v}\ & \mathcal{J}(v, X_0) \\ \text{subject to: } & dX_t = v(X_t, t)\,dt + \sigma\,dW_t, \\ & \rho_0 = q_0, \qquad \rho_T = q_T, \\ & \rho_{t_k} \approx q_{t_k}, \quad k = 1, \dots, K-1. \end{aligned} \tag{6}$$
Remark 4. To better understand our model, imagine a string on a flat surface. The goal is to shape the string so that it connects two fixed points (the initial and terminal constraints) while passing through specific points on the surface (the intermediate constraints). TrajectoryNet holds one end of the string fixed and lets the other end roam freely, hoping that it ends up close to the intermediate checkpoints and the terminal point. In contrast, our model fixes both ends of the string, increasing the likelihood of passing through the intermediate checkpoints and providing a better approximation of the intermediate constraints.

Implementation
In order to solve the optimization problem in Eq 6, we need to address the following numerical problems.
First, we need to compute the cost functional $\mathcal{J}(v, X_0)$, which involves the expectation, with respect to the initial distribution $q_0$, of a stochastic integral over continuous time. To approximate this integral, we discretize time: by discretizing the SDE model on a dense set of time points between 0 and $T$, we generate a finite number of simulated trajectories to estimate the expectation. In particular, since the cost functional is the expected cost of trajectories originating from initial points distributed according to $q_0$, we sample $m$ cells from the observations at $t = 0$ and use a discrete-time approximation of the SDE model to simulate $m$ trajectories.
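The resulting estimator can be sketched as a left-endpoint Riemann sum over the time grid, averaged over the $m$ simulated paths. This sketch assumes the integrand has the form $\frac{1}{2}\|v\|^2 + \alpha_F F(\rho)$ and leaves $F$ as a parameter; the function name `estimate_cost` and the default $F = \log$ are hypothetical choices for illustration, not the paper's implementation.

```python
import numpy as np

def estimate_cost(paths, grid, drift, density, alpha_F, F=np.log):
    """Monte Carlo / left-endpoint Riemann estimate of the cost functional.

    paths:   array (L+1, m, d) of simulated cells on the time grid
    grid:    array (L+1,) of time points t'_0 < ... < t'_L
    drift:   callable v(x, t) -> (m, d) array
    density: callable (x, l) -> (m,) estimated density at grid index l
    F:       interaction cost applied to the density (hypothetical default: log)
    """
    m = paths.shape[1]
    per_path = np.zeros(m)
    for l in range(1, len(grid)):
        x, t = paths[l - 1], grid[l - 1]
        delta = grid[l] - grid[l - 1]
        kinetic = 0.5 * np.sum(drift(x, t) ** 2, axis=1)   # 0.5 * ||v||^2 per path
        interaction = alpha_F * F(density(x, l - 1))        # crowding cost per path
        per_path += delta * (kinetic + interaction)
    return per_path.mean()                                   # average over the m paths

# Toy check: constant drift (1, 1) and unit density, so the log-density cost vanishes.
grid = np.linspace(0.0, 1.0, 11)
paths = np.zeros((11, 5, 2))
cost = estimate_cost(paths, grid,
                     drift=lambda x, t: np.ones_like(x),
                     density=lambda x, l: np.ones(len(x)),
                     alpha_F=0.1)
print(cost)
```

In the toy check the estimate is approximately $\frac{1}{2}\|(1, 1)\|^2 \cdot T = 1$, up to floating-point rounding of the interval lengths.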
Second, we need to estimate the true distribution $q_t$ and the generated distribution $\rho_t$ at all time points. This requires computing probability density functions (PDFs) at the time points used in the approximation of the stochastic integral. To do so, we apply kernel density estimation to the empirical distributions.
Third, we need to approximate the intermediate constraints $\rho_{t_k} \approx q_{t_k}$. To ensure that the simulated distributions closely match the empirical distributions at intermediate time points, we introduce a new penalty term in addition to those proposed by TrajectoryNet.
Lastly, we need to satisfy the terminal equality constraint, which ensures agreement between the simulated distribution $\rho_T$ and the true distribution $q_T$ at the final time point $T$. To enforce this constraint, we incorporate the forward-backward stochastic differential equations (FBSDE) model.
Additionally, in our notation, we assume the observations $X_t$ are collected at $K + 1$ observed time points $0 = t_0 < t_1 < t_2 < \dots < t_{K-1} < t_K = T$. At each time point $t_k$, we have $n_{t_k}$ observations, denoted by $x_{t_k,1}, x_{t_k,2}, \dots, x_{t_k,n_{t_k}}$.
Estimation of cost functional. To estimate the cost functional, which is the expectation of a stochastic integral, we use a discretization approach: the integral of a function over continuous time is approximated by summing the function values multiplied by the interval lengths over many small intervals. Therefore, we need to generate a set of discrete time points for this purpose.
More specifically, to account for the constraints imposed by the optimization problem in Eq 6 at the observed time points $\{t_k\}_{0 \le k \le K}$, the set of discrete time points $U$ is constructed by taking the union of the observed time points $\{t_k\}_{0 \le k \le K}$ and a denser set of equally spaced time points $\{t'_l\}_{0 \le l \le L'}$ with $t'_0 = 0$ and $t'_{L'} = T$. With a slight abuse of notation, we denote
$$U = \{t_k\}_{0 \le k \le K} \cup \{t'_l\}_{0 \le l \le L'} = \{t'_l\}_{0 \le l \le L},$$
where the elements of $U$ are relabeled in ascending order, so that $t'_0 = 0$ and $t'_L = T$. Including the observed time points ensures that the discretization captures the necessary information at these specific time points.
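The construction of $U$ can be sketched with NumPy's `np.union1d`, which returns the sorted unique elements of its two inputs. The helper name `build_time_grid` is hypothetical; the example uses the setting of Remark 5 below ($T = 1$, observed times 0.33 and 0.66, 20 subintervals).

```python
import numpy as np

def build_time_grid(observed, T=1.0, n_sub=200):
    """Union of the observed time points with an equally spaced grid on [0, T].

    Returns the sorted array U = {t'_0 = 0 < ... < t'_L = T} and the
    interval lengths Delta_l = t'_l - t'_{l-1}.
    """
    dense = np.linspace(0.0, T, n_sub + 1)                       # t'_0 = 0, ..., t'_{L'} = T
    U = np.union1d(dense, np.asarray(observed, dtype=float))     # sorted, exact duplicates removed
    return U, np.diff(U)

U, deltas = build_time_grid([0.0, 0.33, 0.66, 1.0], T=1.0, n_sub=20)
print(len(U))   # 23: the 21 grid points plus 0.33 and 0.66
```

Because `np.union1d` only removes exact duplicates, an observed time that matches a grid point only up to floating-point rounding would leave a near-zero interval; rounding both arrays, or dropping intervals below a small tolerance, avoids this.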
Remark 5. Consider, for example, a scenario where $T = 1$ and the observed time points are 0.33 and 0.66. In this case, $U$ is constructed by merging the sets $\{0, 0.05, 0.1, \dots, 0.95, 1\}$ and $\{0.33, 0.66\}$ and arranging the result in ascending order. Our empirical analysis of three datasets indicates that dividing the interval into approximately 200 subintervals tends to be adequate. Nonetheless, we advise users of our model to tailor the number of subintervals to the characteristics of their data, such as the number of observed time points and the overall complexity of the data. It is also beneficial to normalize the interval $[0, T]$ to $[0, 1]$ by applying a scaling factor of $1/T$.
Additionally, the expectation is approximated by averaging over $m$ simulated paths generated by discretizing the SDE $dX_t = v(X_t, t)\,dt + \sigma\,dW_t$. The simulation begins by selecting an initial set of $m$ cells with gene expressions $x^*_{0,j}$, $1 \le j \le m$, from the initial observations $\{x_{0,i}\}_{1 \le i \le n_0}$. With $U$ as the set of time points, the interval lengths are $\Delta_l = t'_l - t'_{l-1}$ for $1 \le l \le L$. The simulations are generated with the Euler-Maruyama scheme [21]:
$$x^*_{t'_l,j} = x^*_{t'_{l-1},j} + v\big(x^*_{t'_{l-1},j},\, t'_{l-1}\big)\,\Delta_l + \sigma\sqrt{\Delta_l}\;\epsilon_{t'_{l-1},j}. \tag{7}$$
Here, $x^*_{t'_l,j}$ is the gene expression of the $j$-th simulated cell at time $t'_l$ for $1 \le j \le m$ and $0 \le l \le L$, and the noise terms $\epsilon_{t'_{l-1},j}$ are independent standard normal random vectors in $\mathbb{R}^d$. Eq 7 approximates the constraint $dX_t = v(X_t, t)\,dt + \sigma\,dW_t$, and the accuracy of this approximation improves as the number of simulated time points $L$ increases.
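Eq 7 can be sketched in NumPy as follows. The function name `euler_maruyama` and the rotating example drift are hypothetical illustrations; in the method itself the drift $v$ is a learned model and the grid would be the set $U$, which need not be equally spaced.

```python
import numpy as np

def euler_maruyama(drift, x0, grid, sigma=1.0, rng=None):
    """Simulate dX_t = v(X_t, t) dt + sigma dW_t with the Euler-Maruyama scheme.

    drift: callable v(x, t) -> array with the same shape as x, i.e. (m, d)
    x0:    (m, d) initial cells x*_{0,j}, e.g. subsampled from the observations at t = 0
    grid:  increasing time points t'_0, ..., t'_L (need not be equally spaced)
    Returns an array of shape (L+1, m, d).
    """
    rng = np.random.default_rng() if rng is None else rng
    x0 = np.asarray(x0, dtype=float)
    paths = np.empty((len(grid),) + x0.shape)
    paths[0] = x0
    for l in range(1, len(grid)):
        delta = grid[l] - grid[l - 1]                 # Delta_l
        eps = rng.standard_normal(x0.shape)           # independent N(0, I) noise
        paths[l] = (paths[l - 1]
                    + drift(paths[l - 1], grid[l - 1]) * delta
                    + sigma * np.sqrt(delta) * eps)
    return paths

# Usage: a rotating drift v(x, t) = (-x_2, x_1) bends paths into arcs,
# which a straight-line interpolation between endpoints cannot represent.
rng = np.random.default_rng(0)
x0 = rng.normal(size=(100, 2))
rotate = lambda x, t: np.stack([-x[:, 1], x[:, 0]], axis=1)
paths = euler_maruyama(rotate, x0, np.linspace(0.0, 1.0, 201), sigma=0.1, rng=rng)
print(paths.shape)   # (201, 100, 2)
```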
The approximation scheme outlined above produces $m$ simulated paths. To evaluate the stochastic integral, we need to estimate the probability density function (PDF) $\rho_t$ for each $t \in U$. Furthermore, we need to estimate $q_t$ to ensure that the constraints $\rho_t \approx q_t$ are satisfied at the observed time points $t$.
In particular, since the true density functions $\rho_{t_k}$ and $q_{t_k}$ are unavailable, we replace them by their estimates $\hat{\rho}_{t_k}$ and $\hat{q}_{t_k}$ when solving the optimization problem in Eq 5. To estimate $q_{t_k}$ and $\rho_{t_k}$, we use a kernel density estimator. Based on the observations $x_{t_k,1}, x_{t_k,2}, \dots, x_{t_k,n_{t_k}}$, we estimate $q_{t_k}$ by
$$\hat{q}_{t_k}(x) = \frac{1}{n_{t_k}} \sum_{i=1}^{n_{t_k}} K_H\big(x - x_{t_k,i}\big),$$
where $K_H$ is a kernel function with bandwidth matrix $H$, a diagonal matrix with diagonal entries $h$. The bandwidth $h$ is tuned by cross-validation.
Similarly, based on the simulated points $\{x^*_{t'_l,j} : 0 \le l \le L,\ 1 \le j \le m\}$, we estimate the distribution of the simulated points at time $t'_l$ by
$$\hat{\rho}_{t'_l}(x) = \frac{1}{m} \sum_{j=1}^{m} K_H\big(x - x^*_{t'_l,j}\big).$$
Again, the same kernel function $K_H$ is used to compute the empirical distributions.
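Both density estimates can be sketched with one Gaussian kernel. This is an assumption: the text does not fix the kernel $K$, and here the diagonal entries of $H$ are treated as per-dimension kernel standard deviations. The helper name `kde` is hypothetical. SciPy's `scipy.stats.gaussian_kde` is an off-the-shelf alternative, but it scales the bandwidth by the sample covariance rather than using a fixed diagonal $H$.

```python
import numpy as np

def kde(points, h):
    """Gaussian kernel density estimate with a diagonal bandwidth.

    points: (n, d) sample (observations x_{t_k,i} or simulations x*_{t'_l,j})
    h:      scalar or (d,) per-dimension kernel standard deviations
    Returns a function mapping (q, d) query points to (q,) density values.
    """
    points = np.asarray(points, dtype=float)
    n, d = points.shape
    h = np.broadcast_to(np.asarray(h, dtype=float), (d,))
    norm = n * np.prod(h) * (2.0 * np.pi) ** (d / 2.0)   # 1/n factor and kernel normalization

    def density(x):
        x = np.atleast_2d(np.asarray(x, dtype=float))
        z = (x[:, None, :] - points[None, :, :]) / h     # (q, n, d) scaled differences
        return np.exp(-0.5 * np.sum(z ** 2, axis=-1)).sum(axis=1) / norm

    return density

q_hat = kde(np.array([[0.0], [1.0]]), h=0.5)
xs = np.arange(-5.0, 6.0, 0.01)
total = q_hat(xs[:, None]).sum() * 0.01
print(total)   # approximately 1, since an estimated density integrates to one
```

The same function serves for $\hat{q}_{t_k}$ (built from observations) and $\hat{\rho}_{t'_l}$ (built from simulations), matching the requirement that one kernel $K_H$ is shared.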
To summarize, we use the discretized SDE to simulate multiple trajectories and kernel density estimation to estimate the densities along them; together, these allow us to approximate the expected value of the stochastic integral and hence the cost functional. Specifically, for a given action $v$, the cost functional is computed by averaging the discretized integrand of Eq 2 over the $m$ simulated trajectories, with the densities evaluated by $\hat{\rho}_{t'_l}$.
Approximation of intermediate constraints. To approximate the intermediate constraints $\hat{\rho}_{t_k} = \hat{q}_{t_k}$, we introduce an additional penalty term alongside the penalties proposed by TrajectoryNet [18]. The purpose of this new term is to improve the recall of the model, as illustrated in Fig 2b. We first introduce some notation. Let
$$D^p_{t_k,j} = \big\{\|x^*_{t_k,j} - x_{t_k,i}\| : 1 \le i \le n_{t_k}\big\}, \qquad D^r_{t_k,i} = \big\{\|x_{t_k,i} - x^*_{t_k,j}\| : 1 \le j \le m\big\}.$$
The set $D^p_{t_k,j}$ contains the distances between the $j$th simulation at $t_k$ and all observations at $t_k$, and the set $D^r_{t_k,i}$ contains the distances between the $i$th observation at $t_k$ and all simulations at $t_k$. We sort both sets in ascending order, so that $d^{p(z)}_{t_k,j}$ denotes the $z$th smallest element of $D^p_{t_k,j}$ and $d^{r(z)}_{t_k,i}$ denotes the $z$th smallest element of $D^r_{t_k,i}$. We then impose a penalty function $g_k$ at each observed time point $t_k$, a weighted combination of three penalty terms with positive regularization constants $\alpha_1$, $\alpha_2$, and $\alpha_3$. The first two terms were proposed by TrajectoryNet. The first applies a maximum likelihood criterion to the simulations, computed from the empirical distributions. It encourages the simulations to closely match the empirical distribution of the observations at time $t_k$, so that the simulations capture the statistical characteristics of the observed data.
The second term aims to improve the precision of the model by penalizing large distances between the simulated data points and observations.It pushes the simulations to be closer to the observed data points, promoting a better alignment between the simulated trajectories and the actual observations.By minimizing this penalty term, the model becomes more precise in capturing the observed trajectory patterns.
We introduce the third penalty term to improve the recall of the model. It ensures that more observations are included in a close neighbourhood of the simulations. This term is necessary when the simulated points are close to only a few observations: the first penalty term is then close to zero, while the distribution of the simulated points, $\hat{\rho}_{t_k}$, is not close to the empirical distribution $\hat{q}_{t_k}$ (see Fig 2a and 2b).
With the combination of these three penalty terms, we aim to improve both precision and recall, leading to more accurate trajectory inference.
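The two distance-based penalties can be sketched as follows. Assumptions of this sketch: distances are Euclidean, each term sums the distances to the $k$ nearest neighbours and averages over points, and the weights $\alpha_2$, $\alpha_3$ and the likelihood term are omitted. The function name `knn_penalties` is hypothetical, and `k` here counts neighbours; it is unrelated to the time index $t_k$. The toy example reproduces the situation described above: every simulation collapses onto a single observation, so the precision term is zero while the recall term stays large.

```python
import numpy as np

def knn_penalties(sim, obs, k=1):
    """Distance-based penalties at one observed time point.

    sim: (m, d) simulated cells x*_{t_k,j};  obs: (n, d) observed cells x_{t_k,i}.
    precision: mean over simulations of the summed distances to their k nearest
               observations (the k smallest elements of each sorted set D^p).
    recall:    mean over observations of the summed distances to their k nearest
               simulations (the k smallest elements of each sorted set D^r).
    """
    dist = np.linalg.norm(sim[:, None, :] - obs[None, :, :], axis=-1)   # (m, n)
    precision = np.sort(dist, axis=1)[:, :k].sum(axis=1).mean()
    recall = np.sort(dist, axis=0)[:k, :].sum(axis=0).mean()
    return precision, recall

obs = np.stack([np.arange(10.0), np.zeros(10)], axis=1)   # observations spread along a line
sim = np.zeros((10, 2))                                   # all simulations collapsed onto obs[0]
precision, recall = knn_penalties(sim, obs, k=1)
print(precision, recall)   # 0.0 4.5
```

Adding the recall term to the loss penalizes exactly this collapse, which the precision term alone cannot detect.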
Satisfying terminal constraint. This section presents a detailed implementation of forward-backward stochastic differential equations (FBSDE) to achieve the terminal constraint $\hat{\rho}_T = \hat{q}_T$ mentioned earlier.
In the continuous setting, an FBSDE model is a system of two stochastic differential equations, a forward process $X_t$ and a backward process $Y_s$, driven by two independent standard Brownian motions $\{W^f_t\}_{0 \le t \le T}$ and $\{W^b_s\}_{0 \le s \le T}$, respectively. Both processes $\{X_t\}_{0 \le t \le T}$ and $\{Y_s\}_{0 \le s \le T}$ take values in $\mathbb{R}^d$ and have constant volatility $\sigma$. The process $\{X_t\}_{0 \le t \le T}$ is equivalent to the trajectory in Eq 1, describing the gene expression changes of a cell over time. The forward drift $v$ maps the current gene expression of a cell and the current time to the direction in which the cell changes its gene expression at that time. The backward process $\{Y_s\}_{0 \le s \le T}$ traverses the trajectory $\{X_t\}$ in reverse order, from $X_T$ back to $X_0$. In other words, if the distribution of gene expression of fully differentiated cells is $\rho_T$ at the terminal time $T$, the initial distribution $\rho_0$ at time 0 can be obtained by applying the backward process to the random variable $X_T$. Finally, the mathematical relationship between the forward drift $v$ and the backward drift $u$ is stated in the following theorem.
Theorem 1.
$$v(x, t) - u(x, T - t) = \sigma \nabla_x \log \rho_t(x) \qquad (9)$$
Theorem 1 establishes a relationship between the forward drift $v$ for a cell with gene expression $x$ at time $t$ and the backward drift $u$ for a cell with the same gene expression $x$ at time $T - t$. Based on this relationship, Vargas et al. [21] proposed an algorithm to solve a simplified optimization problem that imposes only the initial and terminal constraints $\rho_0 = q_0$ and $\rho_T = q_T$. Because this problem ignores the intermediate time points, we incorporate the penalty terms discussed in Approximation of intermediate constraints to approximate the intermediate constraints of our main optimization problem, Eq 11.
The algorithm proposed by Vargas et al. [21], as shown in Fig 1, is an iterative procedure involving two optimization problems: a forward optimization problem and a backward optimization problem. The forward optimization problem enforces the constraint $\rho_0 = q_0$ and approximates the remaining constraints $\rho_{t_k} \approx q_{t_k}$ for all $k = 1, 2, \ldots, K$. Conversely, the backward optimization problem enforces the constraint $\rho_T = q_T$ and approximates the constraints $\rho_{t_k} = q_{t_k}$ for $k = 0, 1, 2, \ldots, K - 1$. Both optimization problems incorporate a new penalty term. Specifically, the forward optimization problem penalizes the discrepancy between simulations generated by the forward drift and references generated by the solution to the backward optimization problem. Similarly, the backward optimization problem penalizes the discrepancy between simulations generated by the backward drift and references generated by the solution to the forward optimization problem.
The forward optimization problem is formulated as follows. Let $u'$ be the backward drift, and let $x'_{0,j}$ be the initial points sampled from the observations $x_{0,i}$. We define the $m$ reference paths by Eq 12, where $x'_{t'_l,j}$ represents the $j$-th reference at time $t'_l$. The corresponding forward optimization problem, Eq 13, minimizes the cost functional $J(v; q_0)$ with respect to the action $v$.
Note that the objective function of the forward optimization problem includes an additional penalty term, with regularization constant $\beta_f$ tuned by cross-validation, to account for the divergence between the simulations $x^*_{t'_l,j}$ and the reference paths $x'_{t'_l,j}$. This term penalizes the differences between the trajectories generated by the optimal forward drift and the reference paths. Additionally, the terminal equality constraint $\rho_T = q_T$ is converted into the penalty term $g_K$ in the objective function.
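The extra penalty compares simulations and references point by point on the shared time grid $t'_l$. The sketch below assumes a squared Euclidean distance averaged over paths $j$ and grid points $l$; that choice and the helper name `reference_penalty` are hypothetical, and the paper's Eq 13 may use a different discrepancy. The backward problem's analogous term would have the same shape with its own regularization constant.

```python
import numpy as np

def reference_penalty(sim_paths, ref_paths, beta_f):
    """beta_f * mean squared distance between simulations x*_{t'_l,j} and references x'_{t'_l,j}.

    Both arrays have shape (L + 1, m, d), indexed as [l, j, :].
    """
    if sim_paths.shape != ref_paths.shape:
        raise ValueError("simulations and references must be aligned on the same grid")
    sq_dist = ((sim_paths - ref_paths) ** 2).sum(axis=-1)  # (L + 1, m) squared distances
    return beta_f * sq_dist.mean()
```

Averaging rather than summing keeps the penalty's scale independent of the number of reference paths $m$ and grid points, so $\beta_f$ need not be retuned when those change.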
In summary, the first three terms in the objective function are similar to the mean field problem described in Eq 11, with an additional term that quantifies the differences between the simulations generated by the optimal forward drift and the reference paths generated by the backward drift.
The backward optimization problem is defined as follows. Let $v'$ be the forward drift, and let $x'_{T,j}$ be the terminal points sampled from the observations $x_{T,i}$. We define the $m$ reference paths by Eq 14, where $x'_{t'_{l-1},j}$ represents the $j$-th reference path at time $t'_{l-1}$.
The corresponding backward optimization problem is given by Eq 15. Similar to the forward optimization problem, the equality constraint $\rho_0 = q_0$ is converted into the penalty term $g_0$. The objective function of the backward optimization problem includes the first three terms, which are the same as those of the mean field problem described in Eq 11, and an extra term that quantifies the discrepancies between the simulations generated by the optimal backward drift and the references generated by the forward drift.
To summarize, the algorithm proposed by Vargas et al. [21] iterates between the forward and backward optimization problems: the optimal forward drift generates references for computing the optimal backward drift, and vice versa. In each iteration, penalty terms penalize the differences between simulations generated by the optimal drift in one direction and references generated by the drift in the opposite direction. The algorithm has been shown to converge [21,22,26,28], with the optimal forward drift satisfying the terminal constraint $\rho_T = q_T$ and the optimal backward drift satisfying the initial constraint $\rho_0 = q_0$. This implies that the optimal forward drift obtained from this iterative process is indeed the solution to the original optimization problem in Eq 11, satisfying both the initial and terminal constraints.
To prepare for the computational algorithm, we parameterize the drift functions $v$ and $u$ using neural networks with parameters $\phi$ and $\psi$, respectively, and train these parameters so that the networks approximate the desired drift functions. Both neural networks are fully connected, with a default setting of three hidden layers of 128 units each, and use ReLU activation functions throughout.
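The default architecture can be sketched as a plain NumPy forward pass. Two details are assumptions not stated in the text: the time $t$ is appended to the gene expression $x$ as an extra input feature, and the weights use He-style initialization. The helper names `init_drift_net` and `drift` are hypothetical; training would in practice use an automatic-differentiation framework rather than this inference-only sketch.

```python
import numpy as np

def init_drift_net(d, hidden=128, n_hidden=3, seed=0):
    """Weights for a fully connected net mapping (x, t) in R^{d+1} to a drift in R^d."""
    rng = np.random.default_rng(seed)
    sizes = [d + 1] + [hidden] * n_hidden + [d]
    # He-style initialisation (an assumption; the paper does not specify one)
    return [(rng.standard_normal((a, b)) * np.sqrt(2.0 / a), np.zeros(b))
            for a, b in zip(sizes[:-1], sizes[1:])]

def drift(params, x, t):
    """Evaluate the drift at points x of shape (m, d) and scalar time t."""
    h = np.concatenate([x, np.full((x.shape[0], 1), t)], axis=1)  # append time as a feature
    for W, b in params[:-1]:
        h = np.maximum(h @ W + b, 0.0)  # ReLU hidden layers
    W, b = params[-1]
    return h @ W + b                    # linear output layer: one drift component per dimension

phi = init_drift_net(d=2)               # parameters of v(x, t; phi) for a 2-D embedding
v_out = drift(phi, np.zeros((5, 2)), 0.5)
```

The output layer is left linear so the drift can take any sign and magnitude in each coordinate.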

Algorithm 1 Solving FBSDE Model in Eq 11
Initialize: $\phi_0$, $\psi_0$
1: $N \leftarrow 1$
2: while $N < N_{\max}$ do
3:   Sample $\{x'_{t'_l,j}\}_{1 \le j \le m,\, 0 \le l \le L}$ with $u(y, s; \psi_{N-1})$ and Eq 12
4:   $\phi \leftarrow \phi_{N-1}$
5:   while not converged do
6:     Simulate $\{x^*_{t'_l,j}\}_{1 \le j \le m,\, 0 \le l \le L}$ with $v(x, t; \phi)$
7:     Compute $J(v; q_0)$ in Eq 13
8:     $\phi \leftarrow \mathrm{StochasticGradientDescent}(\nabla_\phi J(v; q_0))$
9:   end while
10:  $\phi_N \leftarrow \phi$
11:  Sample $\{x'_{t'_l,j}\}_{1 \le j \le m,\, 0 \le l \le L}$ with $v(x, t; \phi_N)$ and Eq 14
12:  $\psi \leftarrow \psi_{N-1}$
13:  while not converged do
14:    Simulate $\{x^*_{t'_l,j}\}_{1 \le j \le m,\, 0 \le l \le L}$ with $u(y, s; \psi)$
15:    Compute $J(u; q_T)$ in Eq 15
16:    $\psi \leftarrow \mathrm{StochasticGradientDescent}(\nabla_\psi J(u; q_T))$
17:  end while
18:  $\psi_N \leftarrow \psi$
19:  $N \leftarrow N + 1$
20: end while
21: return $\phi$, $\psi$
Parameter tuning. Following the TrajectoryNet methodology, we conducted a grid search for optimal parameter values, testing $\alpha_F, \alpha_1, \alpha_2, \alpha_3 \in \{1, 5, 10\}$ and $\sigma \in \{0.01, 0.1, 1\}$, with performance evaluated on a validation set. For normalized data, values of the volatility $\sigma$ between 0.01 and 0.1 are effective: higher values lead to "blurry" trajectories due to excessive noise, while values approaching zero make the model nearly deterministic, behaving like an ordinary differential equation (ODE) model rather than an SDE model.
Remark 6. Similar to TrajectoryNet, our FBSDE model framework can incorporate a velocity penalty term. In our extensive parameter validation studies, including this velocity penalty did not markedly influence the performance of the FBSDE model on the datasets we examined. This minimal impact may be specific to the attributes of our selected datasets; other datasets with distinct characteristics could benefit more substantially from a velocity penalty. Recognizing this variability, the developers of TrajectoryNet provide the velocity penalty as an optional feature in their code repository, reflecting that its relevance depends on the dataset being analyzed. Following the same approach, we have implemented the velocity penalty as an optional element in our FBSDE model codebase, so that users can activate it when analyzing datasets where RNA velocity considerations are deemed important.
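The parameter grid described under Parameter tuning contains $3^5 = 243$ configurations (three values for each of $\alpha_F, \alpha_1, \alpha_2, \alpha_3$ and $\sigma$). The sketch below enumerates it; `validation_score` is a hypothetical callable standing in for training a model with a given configuration and scoring it on the validation set (for example by the mean $W_1$ distance on held-out time points).

```python
from itertools import product

# Grid values taken from the Parameter tuning paragraph.
ALPHA_VALUES = (1, 5, 10)
SIGMA_VALUES = (0.01, 0.1, 1)

def parameter_grid():
    """Yield every configuration in the grid as a dict."""
    for a_f, a1, a2, a3, sigma in product(ALPHA_VALUES, ALPHA_VALUES, ALPHA_VALUES,
                                          ALPHA_VALUES, SIGMA_VALUES):
        yield {"alpha_F": a_f, "alpha_1": a1, "alpha_2": a2, "alpha_3": a3, "sigma": sigma}

def grid_search(validation_score):
    """Return the configuration with the lowest validation score.

    `validation_score` (config -> float) is hypothetical: it should train the
    model with `config` and return an error measured on the validation set.
    """
    return min(parameter_grid(), key=validation_score)
```

Because every configuration requires a full training run, an exhaustive grid of this size is only practical on the two-dimensional embeddings used here.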

Overview of the FBSDE model
As alluded to in the Introduction, optimal transport (OT) has recently been adopted in modelling cell growth trajectories and has achieved promising performance [11,12]. Typically, OT is more effective in modelling linear trajectories, whereas SDEs can also model non-linear trajectories. As shown in Fig 3a-d, the paths display greater angular rotation as the model is trained to accommodate higher levels of entropy. This is reflected in the curly shape of the paths, which allows the population to remain denser for longer periods. In contrast, optimal transport employs linear path interpolations, as shown in Fig 3e.
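The linear interpolation used by OT can be sketched on the toy populations described for Fig 3: a standard normal cloud at $t = 0$ and a uniform ring of radius 10 centred at (0, 0) at $t = 1$. The helper names are hypothetical, and the pairing below simply matches points by index; in Waddington-OT the pairs would instead be drawn from the optimal transport plan. Whatever the pairing, every interpolated path is a straight segment, which is the limitation that an SDE with a learned drift avoids.

```python
import numpy as np

def sample_start(n, rng):
    """t = 0 population of Fig 3: standard normal in 2-D."""
    return rng.standard_normal((n, 2))

def sample_ring(n, rng, radius=10.0):
    """t = 1 population of Fig 3: uniform on a ring centred at (0, 0)."""
    theta = rng.uniform(0.0, 2.0 * np.pi, size=n)
    return radius * np.column_stack([np.cos(theta), np.sin(theta)])

def linear_interpolate(x0, x1, h):
    """Straight-line interpolation at time h in [0, 1] between paired endpoints."""
    return (1.0 - h) * x0 + h * x1

rng = np.random.default_rng(1)
x0, x1 = sample_start(50, rng), sample_ring(50, rng)
# Index pairing only (an assumption); OT would supply the pairs instead.
x_half = linear_interpolate(x0, x1, 0.5)
```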

Performance on a human embryonic stem cells dataset
We compared the FBSDE model with WOT and TrajectoryNet by visually examining their generative power on two-dimensional human embryonic stem cell data obtained from Moon et al. [23]. The dataset contains 16,825 cells grown as embryoid bodies, sampled at 10 time points over a period of 27 days at 3-day intervals. All three models were provided with the same starting population of cells at the initial time point (t = 0) and generated samples at 200 equally spaced time points in addition to the already observed time points.

Performance on a mouse embryonic fibroblasts dataset
We further evaluated the generative power of FBSDE, WOT, and TrajectoryNet using a mouse embryonic fibroblasts dataset obtained from Schiebinger et al. [11]. This dataset contains 259,155 cells sampled at 36 time points over an 18-day period. FBSDE is a mesh-free method: once a model has been fitted, simulations can be generated without relying on observations. On this dataset, the FBSDE model captures not only the smooth but also the sharp shape features of the developmental trajectories, such as the sharp turns marked by arrows in Fig 5a. This successful performance provides further evidence of its effectiveness in modelling complex biological processes. A thorough numerical evaluation of the three methods using cross-validation is presented in Cross-validation study.

Performance on an Arabidopsis thaliana stem cells dataset
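Finally, we compared the performance of FBSDE, WOT, and TrajectoryNet on the Arabidopsis thaliana stem cells dataset generated by Shahan and colleagues [30]. This dataset contains 110,427 cells with pseudotime labels ranging from 0 to 50 and exhibits several distinct developmental branches (Fig 6a), making it the most challenging of the three datasets to model. Waddington-OT produces some unrealistic paths between different branches (Fig 6b), and TrajectoryNet fails to capture the geometric features of the true trajectories (Fig 6c). By comparison, the FBSDE simulations in Fig 6d resemble the observations more closely, with a few recognizable branches and almost no unrealistic long-range crossings. In summary, none of the models performs well on the Arabidopsis thaliana stem cells dataset. However, the WOT model retains shape at observed time points, while the FBSDE model is capable of distinguishing some developmental branches.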

Cross-validation study
In this study, we performed a comprehensive benchmarking of the FBSDE method against two other prominent approaches, Waddington-OT (WOT) and TrajectoryNet. To evaluate their performance, we used cross-validation on several well-established single-cell datasets.
For each dataset, we considered a total of $n + 1$ observed time points, labelled $t_0$ to $t_n$. In each validation round, we randomly selected 20% of the time points as test data and used the remaining time points for training. The model then interpolated the data at the held-out time points, as depicted in Fig 2, and the interpolated data were compared with the ground truth (i.e., the test data). This process was repeated $M = 50$ times to ensure robustness, and the average performance metric across all repetitions was reported.
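The splitting scheme can be sketched as follows. One detail is an assumption not stated in the text: the first and last time points are always kept in the training set, so that every held-out time point lies strictly between observed ones and can be reached by interpolation. The helper name `cv_splits` is hypothetical.

```python
import numpy as np

def cv_splits(n_timepoints, test_frac=0.2, repeats=50, seed=0):
    """Yield (train_idx, test_idx) arrays of time-point indices, `repeats` times.

    Assumption: indices 0 and n_timepoints - 1 are never held out.
    """
    rng = np.random.default_rng(seed)
    interior = np.arange(1, n_timepoints - 1)
    n_test = max(1, int(round(test_frac * n_timepoints)))
    for _ in range(repeats):
        test = np.sort(rng.choice(interior, size=n_test, replace=False))
        train = np.setdiff1d(np.arange(n_timepoints), test)
        yield train, test

# Usage: the human embryonic stem cells data have 10 time points (t_0, ..., t_9).
splits = list(cv_splits(10))
```

The reported score would then be the mean of the per-split metric over the $M = 50$ repetitions.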
To assess the accuracy of interpolation, we chose the Wasserstein distance $W_1$ as the standard metric to measure the divergence between the interpolated data and the test data.
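For two uniform empirical distributions with the same number of points, an optimal coupling can be taken to be a permutation, so $W_1$ equals the smallest mean distance over all one-to-one matchings. The sketch below computes this by brute force, which is only feasible for a handful of points and is meant to show what the metric measures; for realistic sample sizes one would use an assignment solver such as SciPy's `scipy.optimize.linear_sum_assignment`. The helper name and the equal-size restriction are assumptions of this sketch, not details of the paper's evaluation code.

```python
from itertools import permutations

import numpy as np

def w1_equal_size(a, b):
    """Exact W1 between two uniform empirical measures with the same number of points."""
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    if a.shape != b.shape:
        raise ValueError("this sketch assumes equal-size samples in the same dimension")
    cost = np.linalg.norm(a[:, None, :] - b[None, :, :], axis=-1)  # (n, n) Euclidean costs
    n = len(a)
    # Try every one-to-one matching and keep the cheapest mean transport distance.
    return min(cost[np.arange(n), list(p)].mean() for p in permutations(range(n)))

# Shifting every point by (3, 4) moves the distribution a distance of 5.
print(w1_equal_size([[0, 0], [1, 0]], [[3, 4], [4, 4]]))
```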
Table 1 offers a comprehensive comparison of the numerical performance of three models across various datasets.Additionally, we assess a variant of our model, named FBSDE_minus.This variant maintains the exact terminal equality constraint but does not include our novel penalty term, which is designed to improve intermediate constraint approximation.We also compare our results with the Stationary-OT method, as adapted by Shahan et al. [12,30], which modifies the regularization term for better numerical stability.
For the human embryonic stem cells dataset, both the WOT and our FBSDE model demonstrate comparable accuracy in aligning with the empirical distributions. In the Arabidopsis thaliana stem cells dataset, the WOT model shows higher precision (defined in Fig 2a), primarily because its simulations are based on observational data. However, the generative strengths of SDE models such as TrajectoryNet and FBSDE become more prominent when simulating a larger number of time points, as seen in the mouse embryonic fibroblasts dataset. Here, our FBSDE model notably outperforms both WOT and TrajectoryNet. The comparison between FBSDE_minus and the full FBSDE model reveals a substantial gap in performance: the FBSDE_minus variant, lacking the new penalty term, exhibits significantly reduced accuracy in trajectory inference, highlighting the importance of this feature in our complete model. Additionally, the TrajectoryNet_plus variant, which incorporates an additional recall penalty term while omitting the terminal equality constraint, shows an observable improvement in accuracy over the original TrajectoryNet model. This improvement highlights the significance of the recall penalty term. Importantly, combining the additional penalty term with satisfaction of the terminal constraint, as in the FBSDE model, significantly enhances model accuracy. Finally, the Stationary-OT method performs at a level similar to that of Waddington-OT.
In conclusion, our cross-validation study indicates that, for the three datasets tested, the FBSDE model stands out among current methods in single cell trajectory inference.

Discussion
In this paper, we present a new method called Forward-Backward Stochastic Differential Equations (FBSDE) to model the developmental trajectories of single cells and compare it to two other recent methods, Waddington-OT and TrajectoryNet. The FBSDE method has several advantages. First, the use of stochastic differential equations allows for the generative modelling of developmental trajectories on a continuous space-time manifold. Second, iterations between the forward SDE and the backward SDE converge to a solution that satisfies both the initial and terminal probability distribution constraints, leading to higher accuracy in trajectory modelling. Third, the FBSDE method integrates concepts from mean field theory to mimic the effects of cell-to-cell interactions, allowing for more realistic biological modelling.
We applied all three methods to several scRNA-seq datasets. In our cross-validation study, the FBSDE method stood out among the compared methods, most clearly on the mouse embryonic fibroblasts dataset. Even where accurate modelling is difficult (e.g., the Arabidopsis thaliana stem cells dataset), the FBSDE method retains a fair degree of similarity to the ground truth and distinguishes some developmental branches.
Although FBSDE demonstrates great promise in generative modelling of single cell developmental trajectories, it has a few limitations. Firstly, our model is applied to two-dimensional data that have been processed and dimensionally reduced from an initial dimensionality of up to 20,000 genes. While the model is promising in low dimensions, the curse of dimensionality should not be overlooked in cases where a large number of genes cannot be reduced in a dataset. Recent works by Chen et al. [31] and Liu et al. [32] aim to address this problem by modelling each gene of each cell with a separate stochastic differential equation. In addition, Qiu et al. [33] proposed incorporating domain knowledge when modelling these differential equations. Secondly, the interactions currently implemented in the FBSDE model do not cover a wide range of biological interactions. Despite their effectiveness, more complex interactions that model typical biological processes such as ligand signalling should be further investigated. Thirdly, the FBSDE model relies on the structure of the implemented neural network, and further exploration of different network architectures is warranted. Alternative forms of neural networks, such as the neural ODE proposed by Chen et al. [34] and recurrent neural networks, may improve the accuracy of generative modelling and should be considered in future investigations.
Our FBSDE method introduces a novel approach to modelling single cell developmental trajectories, and several improvements can be pursued in future work. First, instead of a single SDE, the FBSDE model can incorporate a mixture of stochastic differential equations, analogous to Gaussian mixture models in probability density estimation. A mixture model has the potential to increase accuracy in situations such as the Arabidopsis thaliana stem cells dataset. Second, the FBSDE method can be generalized to spatial settings by combining spatial data with the gene expression data. Spatial information allows for the modelling of cell-to-cell interactions that require physical proximity.

Fig 1. Overview of FBSDE model. Colours in the scatter plots represent time, where red is the earliest and blue is the latest time point. In the forward model, the population is modelled from left to right. In the backward model, the population evolves from right to left. The blue rectangle on top represents the TrajectoryNet framework by Tong et al. [18], which is equivalent to half of one iteration of our FBSDE model. The FBSDE model iterates between the Forward and Backward models; traversing the Forward model generates new simulated data points, which are subsequently used as the training set by the Backward model, and vice versa. https://doi.org/10.1371/journal.pcbi.1012015.g001

Fig 2. Fig 2a and 2b) Comparison between precision and recall. Red circles represent observations and green circles represent simulations. In Fig 2a, a high precision model means that simulations are close to observations; however, there can be some observations without any simulations nearby. In Fig 2b, a high recall model means that most observations have some simulations nearby; however, there can be some simulations that are far away from any observations. Fig 2c and 2d) Comparison between OT and SDE. Filled circles represent cells from t = 0. Unfilled circles represent cells from t = 1. The red colour represents observations. The green colour represents simulations. Triangles indicate interpolated cells at time t = h. Fig 2c) OT infers the endpoints of a path and the trajectory is simply a straight line connecting the endpoints. OT can only infer paths whose endpoints come from observations. Fig 2d) SDE infers both linear and non-linear trajectories using differential equations. In addition, SDE also infers paths originating from points that are not observations (i.e., given an arbitrary starting point). https://doi.org/10.1371/journal.pcbi.1012015.g002

Fig 1 provides an overview of the FBSDE. The key innovation of our approach over previously published methods such as TrajectoryNet [18] is that we model the population growth in both forward and backward directions. As such, the FBSDE framework consists of two parts, the forward model and the backward model; the forward model is essentially the same as TrajectoryNet. It has been widely observed that using the forward model alone may not always satisfy the terminal equality constraint if the true trajectory has many non-linear segments [21]. This can be attributed to the lack of closed-form solutions for stochastic differential equations [22]; in other words, these equations cannot be solved analytically. Numerical approximations can sometimes alleviate these difficulties, but their performance is not always satisfactory. By including a backward model in the FBSDE framework, one can convert the terminal constraint $\rho_T = q_T$ into an initial constraint for the backward model, which can be satisfied numerically [21]. Furthermore, the iteration between the forward model and the backward model has been proven to converge to an optimal solution in which both the initial constraint $\rho_0 = q_0$ and the terminal constraint $\rho_T = q_T$ are satisfied [28]. Fig 2c and 2d illustrate the differences between OT and SDE. An OT model connects the endpoints with a straight line. In contrast, an SDE model infers non-linear trajectories connecting the endpoints. In addition, an OT model (Fig 2c) can only infer trajectories whose endpoints are observations, whereas an SDE model (Fig 2d) can generate trajectories originating from any arbitrary starting point. Fig 2c and 2d display the trajectories of the cells from time t = 0 to t = 1 and demonstrate how to interpolate their gene expression at some intermediate time point t = h, where 0 < h < 1. The generative modelling differences between OT and SDE are further illustrated in Fig 3. Given a set of initial and terminal distributions, OT can only predict a linear trajectory, as depicted in Fig 3e. On the other hand, SDE can infer trajectories with various shapes depending on the parameter choices, as shown in Fig 3a-d. Specifically, the shapes of the trajectories are affected by the extent to which the population accommodates high entropy (i.e. more crowded), as indicated in this example.
Fig 4 illustrates the comparison among the ground truth distribution and samples generated by all three models, demonstrating that our FBSDE model generated samples that more closely resemble the observations from the human embryonic stem cells data. Furthermore, Fig 5b highlights why linear interpolations can fail in trajectory inference: the WOT model interpolated some points, mainly in the upper portion of the figure, that were not seen in real observations. Fig 5c and 5d show that both TrajectoryNet and FBSDE generated simulations that seemed to match the observations at the terminal time point t = T. In addition, both TrajectoryNet and FBSDE generate smooth trajectories as seen in the observations (Fig 5a), in contrast to the segmented linear trajectories generated by the WOT model. However, at intermediate time points, the FBSDE simulations represented the observations more accurately than those of TrajectoryNet. Finally, the FBSDE model visually resembled the observations more closely than the other two models, albeit by a small margin.
Fig 5 shows that all three models performed relatively well on this dataset. Notably, the observations in Fig 5a exhibit two distinct sharp turns (indicated by arrows in Fig 5), which were effectively captured by both WOT (Fig 5b) and FBSDE (Fig 5d) but not by TrajectoryNet (Fig 5c). However, unlike WOT, which relies on simulations using observed points (as depicted in Fig 2c), FBSDE uses a neural network to fully parameterize the population movement.

Fig 3. Difference between FBSDE and Waddington-OT in modelling a simulated dataset. The time is coded in the vertical bars next to each panel; the time points are normalized, with 0 and 1 representing the starting and end points, respectively. At t = 0, the population follows the standard normal distribution. At t = 1, the population is uniformly distributed on the ring centred at (0, 0) with a radius of 10. At 0 < t < 1, the distribution is interpolated by either the FBSDE or the Waddington-OT model. In Fig 3a, Fig 3b, Fig 3c and Fig 3d, the population evolves according to a stochastic differential equation whose drift term is parametrized by a neural network. From Fig 3a to Fig 3d, the population exhibits more angular rotation as it favours a higher level of entropy (i.e. density). In Fig 3e, the endpoint for each point at t = 0 is sampled based on the optimal transport map, and the interpolation is performed linearly by connecting each pair of points.

Fig 4. Performance on a human embryonic stem cells dataset. Fig 4a) Gene expression profiles of single cells are reduced to two dimensions using the PHATE method. The data correspond to a total of 10 time points over 27 days at 3-day intervals. The time points are color coded on a spectrum from red to deep blue. Fig 4b-d) Trajectory inference by three different methods (Waddington-OT, TrajectoryNet and FBSDE). In each experiment, a total of 200 equally spaced time points are added between the starting and end time points. As can be seen, FBSDE provides the trajectory most resembling the ground truth in Fig 4a.
https://doi.org/10.1371/journal.pcbi.1012015.g004
110,427 cells with pseudotime labels ranging from 0 to 50 and exhibits several distinct developmental branches. The use of pseudotime, as established by Shahan et al. [30], provides a meaningful way to encode temporal information, particularly in cases where absolute time points are not available or are less relevant. Fig 6a illustrates the complexity of this dataset, making it the most challenging to model compared to the previous two datasets. Fig 6c shows

Fig 5. Performance on a mouse embryonic fibroblasts dataset. Fig 5a) Gene expression of single cells is reduced to two dimensions using the force-directed layout embedding proposed in the study by Weinreb and colleagues [29]. The data correspond to a total of 37 time points over 18 days at 12-hour intervals. The time points are color coded on a spectrum from red to deep blue. Two visible sharp turns are indicated by arrows. Fig 5b-d) Trajectory inference by three different methods (Waddington-OT, TrajectoryNet and FBSDE). In each experiment, a total of 200 equally spaced time points are added between the starting and end time points. The same color code is used. As can be seen, all three methods provide trajectories similar to the ground truth in Fig 5a, although FBSDE outperforms the others modestly. https://doi.org/10.1371/journal.pcbi.1012015.g005

Fig 6. Performance on an Arabidopsis thaliana stem cells dataset. Fig 6a) Gene expression of single cells is reduced to two dimensions using the UMAP embedding method. The data correspond to a total of 51 time points on a pseudotime scale from 0 to 50. Fig 6b-d) Trajectory inference by three different methods (Waddington-OT, TrajectoryNet and FBSDE). In each experiment, a total of 200 equally spaced time points are added between the starting and end time points. The same color code is used. As can be seen, Waddington-OT produces some unrealistic paths between the different branches seen in Fig 6a. TrajectoryNet fails to capture the geometric features of the true trajectories. FBSDE produces some trajectories that resemble the ground truth to some degree. https://doi.org/10.1371/journal.pcbi.1012015.g006